import pdb

from torch_geometric.datasets import Planetoid
import torch_geometric.utils as utils
from train import *

# Planetoid citation benchmarks this script knows about. `paths` is
# index-aligned with `names`: paths[k] is the root directory handed to
# Planetoid for names[k].
names = ['PubMed', 'Cora', 'Citeseer']
paths = ['/tmp/PubMed', '/tmp/Cora', '/tmp/Citeseer']

# Benchmark selection, by index into `names` (0 -> PubMed).
dataset = names[0]

# Look the root up through the aligned `paths` list (equal to f'/tmp/{dataset}')
# so the directory layout is defined in one place only.
graph = Planetoid(root=paths[names.index(dataset)], name=dataset)

# Work on the first graph in the dataset (for Planetoid, presumably the only one).
data = graph[0]
print(data)
print(f'#class: {graph.num_classes}, #feature_dim: {graph.num_node_features}')

# Device selection: keep the original preference for GPU 7, but fall back to
# GPU 0 when the host has fewer than 8 devices (a bare 'cuda:7' there makes
# the later `.to(device)` fail with "invalid device ordinal"), and to the CPU
# when CUDA is not available at all.
if torch.cuda.is_available():
    device = torch.device('cuda:7' if torch.cuda.device_count() > 7 else 'cuda:0')
else:
    device = torch.device('cpu')

# Dense adjacency matrix of the graph, built from data.edge_index as-is
# (no self-loops are added here). `max_num_nodes` pins the matrix size to
# data.num_nodes: without it, to_dense_adj infers the size from the largest
# index present in edge_index and silently drops trailing isolated nodes.
# to_dense_adj returns a batched [1, N, N] tensor, hence the squeeze.
#
# Self-loop variant (disabled):
# adds_loop, _ = utils.add_self_loops(data.edge_index, num_nodes=data.num_nodes)
# Adj_loop = utils.to_dense_adj(adds_loop, max_num_nodes=data.num_nodes).squeeze(dim=0).to(device)
Adj = utils.to_dense_adj(data.edge_index, max_num_nodes=data.num_nodes).squeeze(dim=0).to(device)

# Hyper-parameters are only configured for the Planetoid citation graphs;
# bail out early for anything else.
if dataset not in ('PubMed', 'Cora', 'Citeseer'):
    raise NotImplementedError

# NOTE(review): `np` is presumably re-exported by `from train import *` — confirm.
epochs = 100  # forwarded to train_laplacian as `epochs`
d = 10        # length of every mu vector below
mu = np.full(d, 0.03)  # NOTE(review): not referenced anywhere in this script
# Shape (5, d): row k holds the k-th level repeated d times.
mus = np.array([0.1, 0.2, 0.3, 0.4, 0.5])[:, np.newaxis] * np.ones(d)

repeat = 2  # presumably the number of repeated runs — confirm in train.py

# To run the Gaussian variant instead, swap in the commented call (same arguments).
# train_gaussian(mus, d, repeat, data, Adj, device, epochs, graph.num_classes, dataset)
train_laplacian(mus, d, repeat, data, Adj, device, epochs, graph.num_classes, dataset)